{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
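    "# 3.2 Linear Regression Implementation from Scratch\n",
    "\n",
    "Now that we have covered the background of linear regression, we can implement it ourselves. Powerful deep learning frameworks remove a great deal of repetitive work, but relying too heavily on their conveniences makes it hard to understand how deep learning actually works. This section therefore shows how to train a linear regression model using only tensors (`tf.Tensor` and `tf.Variable`) and automatic differentiation with `tf.GradientTape`.\n",
    "\n",
    "First, import the packages and modules needed for this section's experiments. The matplotlib package is used for plotting and is configured to display figures inline."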
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1.0\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import tensorflow as tf\n",
    "print(tf.__version__)\n",
    "from matplotlib import pyplot as plt\n",
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
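    "## 3.2.1 Generating the Dataset\n",
    "\n",
    "We construct a simple synthetic training dataset, which lets us directly compare the learned parameters with the true model parameters. Let the training set contain 1000 examples with 2 inputs (features) each. Given a randomly generated batch of features $\\boldsymbol{X} \\in \\mathbb{R}^{1000 \\times 2}$, we use the true weights $\\boldsymbol{w} = [2, -3.4]^\\top$ and bias $b = 4.2$ of the linear regression model, together with a random noise term $\\epsilon$, to generate the labels\n",
    "$$\n",
    "\\boldsymbol{y} = \\boldsymbol{X}\\boldsymbol{w} + b + \\epsilon\n",
    "$$\n",
    "where the noise term $\\epsilon$ follows a normal distribution with mean 0 and standard deviation 0.01. The noise represents meaningless perturbations in the dataset. Let us now generate the dataset."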
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_inputs = 2\n",
    "num_examples = 1000\n",
    "true_w = [2, -3.4]\n",
    "true_b = 4.2\n",
    "features = tf.random.normal((num_examples, num_inputs), stddev=1)\n",
    "labels = true_w[0] * features[:,0] + true_w[1] * features[:, 1] + true_b\n",
    "labels += tf.random.normal(labels.shape, stddev=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that each row of `features` is a vector of length 2, whereas each entry of `labels` is a scalar (so `labels` has shape `(1000,)`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor([-0.622976    0.31687057], shape=(2,), dtype=float32) tf.Tensor(1.871589, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "print(features[0], labels[0])"
   ]
  },
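  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The label formula can also be evaluated in the matrix form $\\boldsymbol{X}\\boldsymbol{w} + b$. The following is a minimal sketch, not part of the original notebook, that checks the column-by-column form and the matrix form agree on a small stand-in batch. The names `X_demo` and `w_demo` are hypothetical and chosen so that they do not overwrite any variable used elsewhere in this section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Hypothetical stand-in data: 5 examples with 2 features each\n",
    "X_demo = tf.random.normal((5, 2))\n",
    "w_demo = tf.constant([[2.0], [-3.4]])  # true weights as a (2, 1) column\n",
    "\n",
    "# Column-by-column form, as used when generating the labels\n",
    "by_column = 2.0 * X_demo[:, 0] - 3.4 * X_demo[:, 1] + 4.2\n",
    "# Matrix form; squeeze turns the (5, 1) product into shape (5,)\n",
    "by_matmul = tf.squeeze(tf.matmul(X_demo, w_demo), axis=1) + 4.2\n",
    "\n",
    "# Largest absolute difference; expected to be at floating-point rounding level\n",
    "print(tf.reduce_max(tf.abs(by_column - by_matmul)).numpy())"
   ]
  },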
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plotting the second feature `features[:, 1]` against the labels `labels` in a scatter plot makes the linear relationship between the two easier to see."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7f0cc841c3d0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAACnCAYAAADqrEtMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO2dfXRU53ngf68+xugDhGYkZJBkfYMrXCwbYQhgO1g4jbMsTnMWUufsmmTTQ3L2xCaOTzd1S5uk68bZ7RKXeHdrSBsHnzZucOvGlIY2IBNAYBQElinIyNLoA0lgeTQjBMwMHkm8+8ed++rOaEYa0Egaofd3jo5GM3fuvKNzn/s87/MppJRoNJrEI2m6F6DRaCKjhVOjSVC0cGo0CYoWTo0mQdHCqdEkKFo4NZoEJWU6PjQnJ0cWFxdPx0drNAnF6dOn+6SUuZFemxbhLC4upqGhYTo+WqNJKIQQndFe02atRpOgaOHUaBIULZwaTYKS0MLpdF3nK6/9Bqfr+nQvRaOZchJaOF/c38ThZhcv7m+a7qVoNFPOtHhrY2X7hkqgKfhbo5ldJLRwluVm8tpXHpruZWg000JCm7UazWxGC6dGk6Bo4dRoEhQtnBpNgqKF8xbxeAPsOuLE4w1M91I0dzgxC6cQ4idCiI+FEOcsz31XCNEjhGgM/nxucpaZOLzZ0MVLBy7wZkPXdC9Fc4dzK6GUnwL/B3g97PmXpZT/O24rSnA2VReG/NZoJouYNaeU8ijgmcS1KBLZdLRn2Pjao2XYM2zTvRTNHU489pzfEEKcDZq92XE4X1TTMZGFVqOJNxMVzr8CyoAq4DKwI9qBQoitQogGIUSDy+Ua86Sbqgt54Yl7R5mOer+nmU1MKH1PStlrPhZC/BjYP8axu4HdANXV1WO2mTdNx3Ama7/n8QZ4s6GLTdWF2lzVJAwT0pxCiIWWP38XOBft2HgwWfs9rZE1iUjMmlMI8QbwaSBHCNENfAf4tBCiCpBAB/C1SVjjpDOWRtZaVTNdxCycUsqnIjz9N3Fcy7QRzYyGEa0KRD1Go5kMErpkLBHQcU3NdKHT98ZhrH2uDu1oJhMtnBMgVkfSRIRY3wBmL9qsnQCxOpImsm/Ve97Zy4wUzvE8qPHwsFrPAUQ8X6yOpInsW/Wed/YyI4VzPG0SD22z50QHO2tb8AWGSbcl3/L5rEI1lhCPx0Teq5nZzEjhHE+b3Iq2ia5lpfp9O9prPKHS8VPNeMxI4Rzvwr8VbRNNy25ZXRJ8JEa9Fg/0XlIzHjNSOCdCuMZaX5nHyTY36yvzRh17uvMKda19ADz3+OKo57gd9F5SMx6zTjhNjeULDJFuS8EXGOJws4tVpb2UPZoZcpwpmCMmroF1P2oV2nA83gB7TnQAki2rS2J2Jt0K2jy+c5l1wmlqKl9gmJcOXGBbTUXE8rRN1YX4AsP4A8OAIQQjF78M+x2ZNxu62FnbAkC6LWVSzFdtHt+5zDrhNDWWxxsg3ZYcVePYM2w89/hidh1x8tKBCyHCtWV1Cem2lHFNUlPArU6lWIlVI2rz+M5FSDn23X8yqK6ulok+2doUjvWVeRxq6o1bzDTWc5g3hReeuPe2NKI2d2cGQojTUsrqSK/NOs0ZK+OZi5Eu/mjPPb+3kcPNrqjnisRENaI2d2c+s144o2mY8VLzIglcuEBYj1u3JFedKxatZs+wqdS/29F+2tyd+cz6xPdoyetjVaPsOdHB4WYXxY50VhTbVWL6+so81i3JVWGZNxu6ONzsYk2Zg2UF8y3vb+elAxfYc6I9amK7Kdi326FBdwmc+dxKJ4SfABuAj6WU9wWfswM/B4oxOiFsllL2x3+Zk0c0DWOEQdoBwZbVxRGzhzrcPl55p0VpUCAkLGP1DO+sbSHdlhzUsiJ4tBiVg2tNlg/XuJrZxUSbSv8hUCul/IEQ4g+Df387fsubGnyBYfacaA+J
RRphkFYAi1AZWLOHNlYtYlWp4TDq9wUiJjRsrFpEui2Z9ZV57DriDPl7X2MP22oqRlWvhOfmamYfE20q/SSwJ/h4D/D5OK0rLsRSC2nGInfWtoaYj5uqC9n6cAlry3NGCZsRZlnCltXFypML8OL+Jg43uzjUZDQl3HOig5cOXOA7b59nU3Uhh5p6eenABQ419fK1R8s41NTLztpW0m3JAPgCQ0pQw81S63fRNZ6zg4k6hPKklJcBpJSXhRAL4rCmuBGLxzJaLNKeYcOReRd1re0cagrNHjIdOqa5anK42UVZboYSZn9gCIC61r6Q8rPw3+sr85Tj6IUn7sWeYRvlNLJ+F0B7YmcBU+atFUJsBbYC3HPPPVPymbF4LM1kg/HebxUWM31v68OlIdlFJ9vcHG528eL+JrZvqOT97gEACrPTWF+ZF2L22jNsSjvuOuIctb8Mv7FE+i56L3pnM1Hh7BVCLAxqzYXAx9EOvJWm0vEiUv7qrQTnre9/+WAzO2tb8QW1IUCaLSnk/Ds2V/HMz85wuNnF4PA56tuNXUBXv59DTb1KeC96Gnjz66sBQwhXFNtZtySX7Rsq1ZqsWnXXESebqgtDPmsiYRbNzGCioZR9wJbg4y3A2xM836Rz+w2kRzysG6vyWbckl41V+SFH2DNsLM3PAgwn08oSY3TM2nIHm6oL2b6hkmJHOk6Xl2ffeE/tSU2P777GHrWXNG8M5j41fL2Rvsd4e1G9V51ZTLSp9A+AvUKIrwIXgU2Tsch4crvB+S2ri1UurhnmWFXaS3a1LVSDBdMhz1y8wraaClaVOjAFuyw3kyerFrGztpW61j4qF87lhSfuZUWxHQB/4CY7a0P3ktFK2iJ9j/CKm3CtqrOGZhYTbSoNUBOntUwJY5VqjdU3yPo+04nkCwyx50S7Crlsqi6k6fI1wNhn+geNipbdR9sAQ8BB8OA9WZy5OABChOw5lxXMD9nDerwB5QFeVnBJhV9MD3H49wivuAFGmcLW35rEZtan71mJ5hE1teWKYjuvvNPC9g2Vqq+QteTMrAEty83A6fKy+2gba8ocADR0GPvPnbUtrC03nmu6NIDHGwgK+5D6LFPb7TnRzuFmFw/eM5+3G3vocPvUvhVGaz9rxQ1IfIHhkFI33Y9oZqGF00I0j6ipHYsd6XS4fQwOn2d5UTbbaspDsoesTpx9jT2YSQqm9gPYVlMR8tyzb5xheZEdECqLyPQIn2xzA+C+HqDT46PYkU5F3lyWFWRFzNOFEW2fbksJlrola4Gcocwa4Yw12dx6IY88NvaMj1TkUpLjo2JBJjtrW9hWUxEiGGbXAyBkz7djc5WKYy5dNI8X9zfxzGMVgBEbrWt1s62mPEQDm/HTstwM/uQ/VPLjY20M3ZTsPtrG2nIH/sBN0mzJgFRmNRAxw0gzM5k1wjkRZ4jVGWQmCDgy78IXGAoxg02BOts9EGJ6mgK650Q7bzdeCmrfmywvslOYnUZbn49HFy/gVND0Nfe0pzs91LW6+eu6dgDq2z2U5WZQ1+qmrtXQqg/eM581ZQ7WV+aRnT6iwa0m7ssHmzG1uLU2dawblq4HnX5mjXBORJOEa9TQbgopIY4YkDy62EiUMitWzAs83ZZCh9tHWW4GlQuz2FnbwroludS19nF5wI/T5QUMgX7u8cWjStPWLcnlmccq+OHBZnyBYc5cvMKZi1cA2Nd4CUNrC7VOjzfAs2+cUYK8/+ylkM8Y64alPbvTz6wRzslwhoSfM1ygTra58Q/eBEa0obFPNRLnHZm2YIikKSRDyKyI8Q/epGJBJoXZ6Thd13nmsQpeeaeFulY3a8sdPL2qiObea9xfkIXVvDX3mYaDyhDMOSlJOF3ekCyksW5Y2iyefmaNcE4VZgw0LTUJ/+BNstJSONbiwu0NsPtoG2W5GWysyqcsN1MJ9o7NVaolipGzOxSyj1xbnsNxpxtx8EPqWvsosqdT1+pGSsPUfezeBSoeCkLFRNdX5vHG
by7S4fZxY+gm9ozUkCyksW5YEy321kwcLZwTIJKn1EzF+8ID+fxlbQvz5qQoQTJDLN95+zzLi+Yrp87GqkWAYZrurG1h6yOlrClzUJabyZzUJBo6jRLZ0pwMHq7IUYK+ND+LRxYbxd0v7m+ivt047lCTkRzx4v4mOtw+VpZk03v1EzrcvlFJ/GOhTdvpRQvnBLBm5JhOoHVLcoPZQw5qn/80Ttd1lQgPRllZbqYtRDOebHNT3+5hc3UBa8sdvN91hfp2D9XFdk539qt95ZzUJBXa2fpIKWnBv01tvbIkm5SkJNZX5qluDWvLc/jRUw+o9VpzdcN7H4U3M9Om7fSihXMCmBet2xsICoKD7RsqVfE1GCl7r33lIfWe177yEF/c9S4Ai7LmsKm6kGMthsPn385/xIDfSEZYtyQXkJbG1oYQHzj3EV39fors6XR6fMohZfXuGvWkRkjH7K5orYB56cAF9jZ0sfvparLTbSF75HAvs9aY04cWzhiJFFowL14jVAGlOZlKS461R1uSN5f6dg/rfyuP5x5fjD/oeR3wD7GmzEF1sT2Y6gcgONnWR317P2d7rqpzdHp8rCyxq32qIcjuEIePqc1fPeKkpfca2zdUsqm6kJ/VX8Tp8vJHb50l3ZaiNH74jUUzvWjhHIdIhdXh2sRsMn2sxRX0jjaFaMtwsjNSAWjru47TdR2AlSXZLLl7HtnpqWysWqRuBM89vpgznbl8dc8p+n2DZKWlYEtKwuUNcOmKn5//povdx9pYlDWHlSXZITeGHZurePWIk78/1clV/zCBoXO88qUHcWTa6PT4aLp8lWs3hinLzWD7hkrKcjNDEvmBcR9rR9HkoYVzHMx95baacrbVVOALDIWNZhjRoKZj5pnHKkbt66x7zy2rS5RWs6b2pdtSeP1d16gkhlfeaaHfN8ic1CQG/EM8eM98XN4AXf1+GjqNUMmlgRtcGrjBz0918Uef+y21rsaL/Vz1Gwn4124M8uwb76k97LUbw8xPT8Hp8rKv0UisD/cUR8o1tj7WZu/koYVzHEaS0gWhscTRJVlluZns2FylAv++wBDPPb4EwCKEhlY1wycriu0MDt+kcmEWX3yokFWlvayvzFO/Xz74IYXZ6SqvFyA1WbCyxE59uwfXNaM2M1nAsITzPQMhDp7LAzfU+kyz2Hquz1TeTafbx1tnuunq97OtpoJtNeX4AsPKixyt+4I2fycXLZzjYGb2WCtQopVkASGBf2u2juGtNbTqywc/xB8YJs2WxJEPP6au1c3DFbmjzEqzHQrAtppywPTs9vP0qnvovXqDDrdPxVSLHekszc9S79t11InHO8iclCRuDN1kZUk2q0pz2Fi1SCXm+wPD7G3vBqDIns6W1cWWEIok3Zai/g+R8441k4UWzgiEO3/C21RahyCFY20YNuLUGfHa7jriDGkKZgq8GeIINSsNT+uaModq2/n9f/mA+vZ+Dje76Or3A+AfvElhdhoF2WnsPtrGQ8XZZKWl4PEae9QB/xDFjnS+/4VllOUaMU5To3//Xz5Qa/mdpXeHfN9INyGdczt1aOGMQHjwPVpubSTsGTalfSJhCq+pOU0BNsMZkUcSGhrY4w3wfreRaNDV72d+WipX/IMAXPEZe9DC7DR+0zHS17vIns7Znqt0uH3sa7yk1mYKV5otSR332fvuVntlM55qtuqMxTGmiS9xEU4hRAdwDRgGhqJNTZopTDT4HmlminWIbni3P2v3PVNYw+syAU539qssoLXlDr71+BK+8bMzXBq4wbVPDKfPovlpSqMWO9K59+55lhCMVGs7+qHL6HckJWvLHdS1ulUvo5NtbpYVZLGztlW16jTjo9bStltFa91bI56ac52Usm/8wxKf2w2+Wx0xMNJS01pZcrZ7gB2bqyIOTQrPrT3Z5laZRb7AkEpIWFliZ3lRNsU5GZTmZnBp4AZZaSl8sbqQLz50D/saezjW0seZi1cYHL4ZfE82IFhRbKcsN4Pj
TjfHncbe+OlVRaQmJ4XUmJotU9ZX5vH9X35A48V+nv5UESDUencdcd7SiERrRpXpYAuf+K0ZQZu1cSTS3BNfYFil0UkpOdzs4s2Grohm8oh2qrCkAfZaWo8YFzTAztpWznYP8PtrS7k8cINv1lTw1ns9gBF3fbvxEgA9V25Q7EinyJHBztoWTrb14XR5WVPmYOjmzeD+9WO6+v1ULMhkWUEWywrmqw4Pu444VQ+kVtd1PN5B1aH+pQMXONbSR11rH77AcIhFECkl0LqXneyJ33cC8RJOCfxKCCGBXcEetSFMR1PpqcZqDlvjo6YG2tfYQ3WxPcQktJp6ZmXJxqpFbKxaxODwOdzXAyquapak7TnRwdpyR7AHrg+ny8uf7jtPv2+Q9r5TPFmVT4fbR3Z6Kv2+QTrcPvwBw+wdHJZsq6nArDv9g394X9V4/qqplw63jxeeuBcwNOOKYjsrS+z09PvpvuJXHe3Nwm739U+CGj20FbH5/cNTAkduNHA7E79nE/ESzjVSykvBcQwHhRAXgrNVFNPRVHqqCe/QZ/42NZB1D2cyYuoNc7b7itKWgKXjgcSReVdIeGVliZ215TkhDcUAJYhmYfa39jbS4faRkmw4lW4MDvPaiTau+oc53dmP0+Wl2JGOPcPGmYtXWFvuCBkPsW5JLvXtHrY+XErthV6cLi8/P9Wl0gGz021qbVaspnp4SuBYXfY1I8RFOKWUl4K/PxZC/BPwEHB07Hfd2YTvW6M5mUZMvaFRIxlMk9HUaAbGfa2+3cO2mnIershhRbGdHx78kEVZd3Fp4BNuDBrnKsxOI2/eXeTNm8O3P3vvqHGFpbmZXB64gdPlZfimcd7BYcm+xh5lilcsMBqKAThdXtaW5/Cr8x8F12MkVEQyS631oOZeWjuCbo0JC6cQIgNIklJeCz7+DPBnE17ZHUY0J1N4yxPrhfyjpx4YNRm73xfgdGc/i7LSON15he89uZRDTb3UtfaplieQA8A/n71Mv2+QRVlz+NbeRr77H5eSO/cu/vX8ZdYtzqPxogeny8uirDnKw1vf7iE1WSjTd2dtq6otXVPmoHLhPKWtt2+ojLi3BELCLtHaeYZ7sbXghhIPzZkH/JMQwjzfz6SU/xqH884azIvUHxjmO2+fU6l/6baUkEoRc9pYXatbmbJmLm9ZbgZbPlXMsoIs/IM3GRy+SX27h7TUJC4FU/ie29vI0kXzuOof5v3uK0obC2EkQ/R7P6H2gpGxZKYTvt14KcSzW12cHcxWMszkZ352huNON0c/dKljTG/z1odL1H47UrWLtcsgjCRGaAwmLJxSyjbg/jisZdYSfpEatZxizK7tZoPrZx6rUE6d7/7zeQqy06hrdVOYnQYY2UOVd2fS7vbR7xukrtVNkT3dYiaDI8PGo4tzee7njfRcMQT5385/RJotWR2XlZbCk/fns2V1idonn+2+ogTS7Mpg7osB0oL5x9FMW6OWtS+o7UdSHTUGOpSSAIzOGjIagEVKEbSawatKHRz50IXT5cWekUqH26e6+5kOIoC5aan4B28q7+3vLL2bNFuSytM923OVP/iH9+n0GII4b04KnR4fhy/0sihrDld8AQb8Q5y7NBCS3re+Mo+KBRdpunyN1aUO9rzbwfrKvGAihbGHffXXrew+1m4ppxu52dgzbPzoqQdCStE0I2jhTACieS/Hiv9FDtUY7TEfXbyAHb9qxj9oFHHfX5BNanIyda19rCyxKyX1/S8s4+e/MYTr99eW8ONjbZQtyKT5o2vUt3tCirsBXNc+Ues1bxBNl69R19pHd78v2I/3HA9XGJrfGD1h7H9Lc4ycXmtoyHquSMz2jCItnDOU8FANoAR81xEnx51u1pbnqFace060U9faxwcfDai5oW+d6eGJ315IXWsfD1fk8MqXHuT5vY3Ut3uUlk0VMCgNea5YkMl//ut6vvfkUspyM9VsGID7C+aRnCQozcnkpQMXWFvuYOvDpQAsL5oPCOpO
GqEhR2Zsc1MjpUHOJmGd6HxOzSRjnalpfWxqnEgX6abqQuW5Pds9QL/PCPoXZqepwmswkucP/Ptl1pYb4RjTM1yWm0G/b5AiezqOuXMAw0h9p9lFXWsfL+5vwuMN4AsMGZoYONrixunykp1hU7m6TZcH2H2sjbPdA2ysWqS6Cq6vzBs1KzTSvNFN1YUheby3P1t1ZqI1Z4ITbfLZSEJ9e0iLzUNNRpH2soL5DA7fDE7ZPk9dax/L8ufR1e9nwVwb89NT6ezz0tXvp6vfT2qyCOkl9MI/nlXVLWYhd0ZqEvcVzGf7hspgMkQrK0uMkRJd/X6y0lLwDw5TmpNJXaubzLtSSEtNUokVjgwbx51uXtzfxLKC+coJtqm6ELc3oATXJNZY8Z2KFs4Ex+zEEKkzgeHlHWkpYo4JNBtJry13sK2mHH9gmLrWPuakGjmxH18LMHdOKsFCFlaW2EeFbIKhMVUPCuAdvKnyak93GqZxfbuHB++ZT1e/nwH/kBq0BHCy3eh4n5aapFL+zJjnsoIsNTh406snlAPL2lc33Iydbd0AtXAmONZODOHj/NZX5nGsxUVpTiZtfV61/zM9tnWtbpYX2UmzpbCtppyNVfmqA8Kji3NVat+qUrvKlTUxOwT+TuXdzElN5nDzx8xPMzr1BYbOcdzp5q4UwSdDkqHhm2yuLuBdp5tPlTmYk5qMlPB7Kwr5k33nuOIb4qs/PcXffHmFas9iCtxXXvsNTpeXIns6n39gUVjt6Ejh+WwSShMtnDOAcHPObBZWsWCuanHyzccXq/kqaalJPLp4AT882Myxlo85c3GAbTXllOVmhgT63/pva1Qfoy/8v+N0uH24rwdwZNowQyFGJUogmEFkxE6vfTIUolHP9lzFGximq9+PLZivC1BdbGdz9T3sPtpGh9vHi/tHp/tt31DJ4PB5KhfOVVlC1uoca1M1mF1d/7RwzgDCzTmzWdjg8E3lMDHCMSOCt+uI09LLCEBEHbT7fFCDArzfbXSbN5MYzE599oxUPlXqoOt0N2e7B0LWt7LEzrc/ey87ftVM2YJMstJSOXPxCv7AUHCGqFH4bdamWinLzeThihxeOnBBJdD7AkNsq6kI6Wdk9jKaTV3/tHDOQMxmYWavWdPzGZ594wsMKU0a2rjLwHxsaq/SnHSae68BqJYnXf1+5qUl4/EOsnD+HNVpvjA7jSfuWxjiiCpbkMnr73ayLH9e8BOE6ulr5t4OFA+qzKZTHZ6QUjkzi8havRPJCaQdQpqEJXzEQ6SBQ+GaFCJ7O02B/tvfX8nLBz+kvv0iK0vs3F+QxfvdA3T1+7nqN0rQNlbl0+8b5Eizi5e/WMWDRdk4XdfZ+nqD2jcCIVrYXJNpqtozUvF4B2nv89Lh9qkcYtOjG75Gq9UwUgc6O9BxzjuA8HhgNKyxUetFPxJvNPaZq0rtpNmSqW/3sLLEzraacrZvqOTF/U28/m4nnR4fpzo8eLwBvvrTU0Z8Mz2V721cSlluBldvGHvR+nYPz77xHh5vgE3Vhawpc6iOgKbX1z94k/WVeaxbkhsSRjGxxkN1nFMz45hIiMHasWD7hkrVLNtMXl9V6mDL6mKVoLC2PIflRfNVQrupJft9g+x5t4PdT1ez53gHv3i/m6t+I4Sz50QHzz2+mOpiO8edbgb8Qwz4DfO56dIA+xqTguGVnlHd7sNbv1h/x4tEzTzSwnmHMtYFZ63BNLN8Dje78AX+nVWlRoxyY9UiznZfYWPVIvacaFdT1H701AOqd68vMBxM0ZO8HxSqigUXOe7s46rfGPVwxTeEqZHNzoJmy5T3u/upa3UHW6eU4x80kiaKHekhA4BPtrlZUWy/7cLt8YQvUeeQauG8Q4l2wVm7AZoJAUZnPsMMNfNuzZYp0ERu5l2AUVr27BvvUblwLmm2FHbWtvDCE/cGJ619SH27R3VtKHak80hFLtkZqarKxtrTd31lHk2XB9TnpiYnUblwLmDsWf/0F8bQpUNN
vUqTRhpTOJH/hUmiZh5p4bxDiXbBjQzateMLDLH1kVKQUN/ez8oSe1BzSpVja2oygKMtfcGa0D7WlDlUw2kwtKI5BLgwO41HKnJ4/WQnWx8uCdFaVjParC3t9PgsbT+zqW/v57jTHRL2sc6PMX+He6hv9X9hkqiZR/FqKv1ZYCeQDPy1lPIH8Tiv5vaxlnVZL2LzAjU6F3hIt6WwY3MVjsyR1/ac6FAZRYeaelmSN5f/8S9NfLOmgr+tv0hPv5/jTjfVFlPzUFOv6pHb1e+nrc/Yi755upt+n9GV3pzEdrLNzZZPFQPwzGMVHPnwY053XqGutY9tNeXcX5jN+Z4BZcqa7VlOtrlZkjeXk21uevp9vH7yYsiwqPH+FzONePQQSgb+L/A40A2cEkLsk1I2TfTcmlBidVxYj4tm0j3/mSXY3mlR8zytnlvTXC3LzaTs0Ux2HXHidHm5NHCDx+5dwEsHLqjJ2y8duKByefPnGxUsK0vsfO/JpSEhFjPLJ9xMXVXq4LnHl4xa83GnG5ulIZlpzr7ffQWPd5C0VDPQIMb9v4S/nqgOoHDioTkfAlqD7UoQQvw98CSghTPOxOq4GMvDab72whP3RhzwO1bQ39op0MzP/cV7l5THtufKDdaUOVian8W+xh7+4j/dzyvvtJA79y521rbiDwzz9U8b09KsfYWss0vtGTalXZ95rIJVpQ6VqABNfOGBfP77P57FP3jT6JsUllwR6f8S/nqiOoDCiYdw5gPWwFM3sDL8oNnQVHqyidVxEV6IHansKtqeLZIJGP6cmYh/tvuKam0CxvyW5UV2VQpmjp549o33AGi6fC2kZab52aaD6qKngTe/vlpp12UF81U81Ey8ePlgsxp3uPvp6qhZRGP93xLVARROPJIQInVmGtU0Wkq5W0pZLaWszs3NjcPHzj7GKrCO9TjztUNNvaMC+uEF0NGeM5Metm+oVG1I1pY7+N6T99Hv/YTC7DQeKMzicLOLPSfaqVw0j2X58+jo83Kms589J9p56cAF9pxoBwimIRp9j57f28j6yrxg13mp1ujxBvj+Lz/grTPGyIknq/LVOMPx/i/hr0dOwEg84qE5uwHrLagAuBSH82puk1hinCuK7aOyciKZe9FSAzHjdpQAAAmtSURBVM3H1gZde0608/rJiwAUOTLUoOHdR9tU25Nv7W3kyar84CcKtQ/d/XS1Sui3zocx83LNaeGAMmdv5XtHItHN23gI5ymgQghRAvQAvwd8KQ7n1dwmY1105mvWQUlmcfN4+81IDaRDTV7DiCrMTqNy4Tz13nRbMs6Pr7H3dA+PVOSwZXUxlwf8/PREO4cvfMzZHiPeaa31hJGu8c/vbVSCWZidpsxZiM35FY1EN2/j0bd2SAjxDeDfMEIpP5FSnp/wyjS3zVgX3XgzTML3g9aQTHjyAoQKwZbVxaTbklWRdJotOaS9Z9mCueq8hz7oZcA/xNmeARVzNc/ndF3n+b2NPPNYhRohYQ4K/sKD+aoS582GLtzeALuPtuELDCttGmsMNNFDLPGalfJL4JfxOJdm4ow3edt8zdSYJuGzRK3neDU44HdlSXZIS5NI5zbNUV9gSM3jNHN2wdjnfflTRbx8qBUJ9F69oUYa7thcxXfeNnoemZUr5n60LDeDjUGT2NSSa8ocwU+Xo0Yp+gLD6uaQyCGTaOgMIY3CzB6yDlMyOR80PVOSklT800r4fs8qpG5vgJ21I8Jilo5JIC01SQmgObu0cuFc6lr7eKQih6ceMvJszf2o2WPIWq9aXZzNxqp8pS2tIZ9E3lOOhxZOjSJSL1yT5z+zhI/+4X2e/0zkbBxTkx1r6WN50XyVXWQ6igxG5nEuyZvLd//5PA8V21k4P00VbJuvW8cKvtnQpbS11WRNt6Wws9aI2ZreZxiZA+p0Xeds90DEUrSZgBZOjWIsc/hUhzGR7FSHhweLske9vqm6UO1FzX65pnlsdkMIdyA91XtNJUSY2jhcA5sm
6rEWF8uLstnXeCmkpab1ty8wrDKR7Bk2FS+1Or1iIVEyiLRwamLCMCNDL34r9gwbOzZXqZF+G6vyWVbQgy9YHmaauS8f/BBz5J/ZPdA6osGaGL9jc1WY0LvV+IlICRamyWyOsl9RbKcsN4MVxfZb+q6JEmLRwqmJCaNFZ+jFH+kYcySExxsI0Z5WL67xdwq+wLCa3m2OaLAKo6m9lhVkUZidjtN1nY2W5AOPN8CrR5yc7xngzz5/3yhN+so7LThdXl55p4XXvvJQzDm2492IpgotnJqYMZ0wbm+Alw82Rx14a/X6ri3PoaHDw3GnW7W6NPee5l50TZkjJLZpjXeaDb/WLcnluNMd0nT6zYYudh9tA+BPf3FOjSA012RthBbJEx1NQ8ZyI5oKtHBqYsZscL2z1rigo124Vq/vsoIsJVxbVheHCHP4XtT6OeZ5I9VzWr2ybm+A8z0DLM3PGiVoZbmZStB9gaFRnujwBAvrlO1ESFAQUo5Kg510qqurZUNDw5R/rmbixDIqPlp/3PHMw1gcMS8fbGZnbWuwZcqD6vyRxt5bs4a21VSMGfM0HU+A6u4wFQghTkspqyO9pjWn5paINks0/BjrxT3ehT4yfmFYeWKjvcc/aBR017W6eX5vIxULMtl9rD2k6NoqaGOFh6yY+0xruGe60cKpmXbCBwGPJRxmkXWRPT3Y9d60/EaKoyKVzEVqvG0llpvOVKOFU3PbRDJDbydGGE27RTrXxqp8znYPqI7x4eZstM+fjEG8kx0P1cKpuW1iLTEb7yKOlvwQSaBGysocEXOEo3lgw4vMYzGhx1v7ZMdDtXBqbptYWprA2BfxWBd/pBYr0XJ/zXP5AsNsqymPmpQ/MsFsfBN6vLVPukdXSjnlP8uXL5ea2YP7+ify1V+3Svf1T0a99uqvW2XRt/fLL/+kXr0e7fixzuO+/on88k/qZdG398tXf916W2uJx/G3CtAgo8iJnpWimXTGaiOyqdqYoWK2NIERbRU+E2Ws84ynVaMRqQ1LrJ852WizVjOt2DNsaoaK6XG9HXMx1pDJjOrEF02lxvIDfBejNUlj8OdzsbxPm7Uzm4mYepHeO9mm41ifP5WfHQkm2ax9WUpZFfzR3RBmARMZxRfpvVNpOkbrxBepEdp0d+bTZq3mlpmIlzIRclZjYaxBUFNV6xkP4fyGEOJpoAF4XkrZH4dzahKYiTTGSvSmWiaRbiLWKd4w+XvUcc1aIcQhIcS5CD9PAn8FlAFVwGVgxxjn2SqEaBBCNLhcrrh9Ac2dz3gm5mSYoJHM3Rf3N6lGY1Oh+ccVTinleinlfRF+3pZS9koph6WUN4EfY8xNiXYe3fFdc1uMt8eNdQ88USHfvqGSdUtyQ/rmTiYTMmuFEAullJeDf/4ucG7iS9JoQrnVWSjRuNWBR+GY81qmionuOf+XEKIKYzZKB/C1Ca9IowljvH1qrPvYeAj5VDqEJhRKkVL+Fynlb0spl0kpN1q0qEYzLYxlmlr3kZGOiyWkYzWhJzvcokMpmjuK25lheiteV6t21VUpGs0tEOv+83bjrZH6G02W51b3ENJoppGxegjpqhSNZoJM1t5TC6fmjmeyHTeR4qzx+Ey959Tc8Uy24ya8/22snQTHQwun5o5nsh03VifRrbZBGQstnJo7nqlMto+16DsWtHBqNHEknjcC7RDSaBIULZwaTYKihVOjSVCmJUNICOECOqf8g8cmB+ib7kVMMbPxO0Nife8iKWXEAudpEc5ERAjREC2N6k5lNn5nmDnfW5u1Gk2CooVTo0lQtHCOsHu6FzANzMbvDDPke+s9p0aToGjNqdEkKFo4LQgh/kIIcUEIcVYI8U9CiPnTvabJQgjxWSFEsxCiVQjxh9O9nslGCFEohDgshPhACHFeCLFtutc0HtqstSCE+AzwjpRySAjxPwGklN+e5mXFHSFEMvAh8DjQDZwCnpJSNk3rwiYRIcRCYKGU8owQYi5wGvh8In9nrTkt
SCl/JaUcCv55EiiYzvVMIg8BrVLKNillAPh74MlpXtOkIqW8LKU8E3x8DfgAyJ/eVY2NFs7o/FfgwHQvYpLIB6zt0btJ8As1ngghioEHgPrpXcnYzLqSMSHEIeDuCC/9sZTy7eAxfwwMAX83lWubQkSE52bF/kYIkQn8I/BNKeXV6V7PWMw64ZRSrh/rdSHEFmADUCPv3A15N2At0S8ALk3TWqYMIUQqhmD+nZTyrelez3hoh5AFIcRngR8Cj0op79hRaEKIFAyHUA3GZPJTwJeklOendWGTiBBCAHsAj5Tym9O9nljQwmlBCNEK3AW4g0+dlFJ+fRqXNGkIIT4H/CWQDPxESvnn07ykSUUIsRY4Bvw7cDP49B8l8jR2LZwaTYKivbUaTYKihVOjSVC0cGo0CYoWTo0mQdHCqdEkKFo4NZoERQunRpOgaOHUaBKU/w9o7aDnRn8cDgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def set_figsize(figsize=(3.5, 2.5)):\n",
    "    plt.rcParams['figure.figsize'] = figsize\n",
    "\n",
    "set_figsize()\n",
    "plt.scatter(features[:, 1], labels, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
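    "## 3.2.2 Reading the Data\n",
    "\n",
    "When training the model, we need to iterate over the dataset and repeatedly read mini-batches of examples. Here we define a generator function that yields, on each iteration, the features and labels of `batch_size` (the batch size) randomly chosen examples."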
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_iter(batch_size, features, labels):\n",
    "    num_examples = len(features)\n",
    "    indices = list(range(num_examples))\n",
    "    # Shuffle the indices so examples are read in random order; every example is still visited once per pass\n",
    "    random.shuffle(indices)\n",
    "    # Take batch_size examples at a time; the last batch may be smaller\n",
    "    for i in range(0, num_examples, batch_size):\n",
    "        j = indices[i: min(i + batch_size, num_examples)]\n",
    "        # tf.gather selects the entries at the given indices along the given axis\n",
    "        yield tf.gather(features, axis=0, indices=j), tf.gather(labels, axis=0, indices=j)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
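    "Let us read and print the first mini-batch. The features of each batch have shape (10, 2), corresponding to the batch size and the number of inputs; the labels have shape (10,), i.e. the batch size."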
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[ 0.04718596 -1.5959413 ]\n",
      " [ 0.3889716  -1.5288432 ]\n",
      " [-1.8489572   1.66422   ]\n",
      " [-1.3978077  -0.85818154]\n",
      " [-0.36940867 -0.619267  ]\n",
      " [-0.15660426  1.1231796 ]\n",
      " [ 0.89411694  1.5499148 ]\n",
      " [ 1.9971682  -0.56981105]\n",
      " [-2.1852891   0.18805206]\n",
      " [ 1.3222371  -1.0301086 ]], shape=(10, 2), dtype=float32) tf.Tensor(\n",
      "[ 9.738684   10.164594   -5.15065     4.3305573   5.568048    0.06494669\n",
      "  0.7251317  10.128626   -0.8036391  10.343082  ], shape=(10,), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "batch_size = 10\n",
    "\n",
    "for X, y in data_iter(batch_size, features, labels):\n",
    "    print(X, y)\n",
    "    break"
   ]
  },
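  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, TensorFlow's `tf.data` API can produce shuffled mini-batches without a hand-written generator. The sketch below is an aside rather than the approach used in this section. It uses stand-in tensors (`demo_features` and `demo_labels`, hypothetical names) with the same shapes as `features` and `labels`, and the buffer size is assumed equal to the number of examples so that the shuffle covers the whole dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Hypothetical stand-in data with the same shapes as features and labels\n",
    "demo_features = tf.random.normal((1000, 2))\n",
    "demo_labels = tf.random.normal((1000,))\n",
    "\n",
    "# Slice into (feature, label) pairs, shuffle over the whole dataset, then batch\n",
    "dataset = (\n",
    "    tf.data.Dataset.from_tensor_slices((demo_features, demo_labels))\n",
    "    .shuffle(buffer_size=1000)\n",
    "    .batch(10)\n",
    ")\n",
    "\n",
    "# Inspect the shapes of the first mini-batch\n",
    "for X_demo, y_demo in dataset.take(1):\n",
    "    print(X_demo.shape, y_demo.shape)"
   ]
  },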
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
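    "## 3.2.3 Initializing Model Parameters\n",
    "\n",
    "We initialize the weights as normal random numbers with mean 0 and standard deviation 0.01, and initialize the bias to 0."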
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = tf.Variable(tf.random.normal((num_inputs, 1), stddev=0.01))\n",
    "b = tf.Variable(tf.zeros((1,)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
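    "## 3.2.4 Defining the Model\n",
    "\n",
    "Below is an implementation of the vectorized expression for linear regression. We use the `matmul` function for matrix multiplication."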
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linreg(X, w, b):\n",
    "    return tf.matmul(X, w) + b"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
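    "## 3.2.5 Defining the Loss Function\n",
    "\n",
    "We use the squared loss described in the previous section to define the loss function for linear regression. In the implementation, we need to reshape the true values `y` to the shape of the predictions `y_hat`, because `y_hat` has shape `(10, 1)` after computation whereas `y` has shape `(10,)`, a one-dimensional vector. The result returned by the following function also has the same shape as `y_hat`."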
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def squared_loss(y_hat, y):\n",
    "    # Reshape y to y_hat's shape first; subtracting a (10,) tensor from a (10, 1) tensor directly would broadcast to (10, 10)\n",
    "    return (y_hat - tf.reshape(y, y_hat.shape)) ** 2 / 2"
   ]
  },
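  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the reshape matters, the following small sketch (using hypothetical stand-in tensors `y_hat_demo` and `y_demo`) compares the shapes obtained with and without it. Without the reshape, broadcasting pairs every prediction with every label, which would silently produce a wrong loss rather than raise an error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Hypothetical stand-ins with the shapes of a prediction batch and a label batch\n",
    "y_hat_demo = tf.zeros((10, 1))\n",
    "y_demo = tf.ones((10,))\n",
    "\n",
    "# Without reshape: (10, 1) minus (10,) broadcasts to (10, 10)\n",
    "print((y_hat_demo - y_demo).shape)\n",
    "# With reshape: the difference keeps the shape (10, 1)\n",
    "print((y_hat_demo - tf.reshape(y_demo, y_hat_demo.shape)).shape)"
   ]
  },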
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
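    "## 3.2.6 Defining the Optimization Algorithm\n",
    "\n",
    "The `sgd` function below implements the mini-batch stochastic gradient descent algorithm introduced in the previous section. It optimizes the loss function by iteratively updating the model parameters. The gradient computed here by automatic differentiation is the sum of the gradients over a batch of examples, so we divide it by the batch size to obtain the average."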
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sgd(params, lr, batch_size, grads):\n",
    "    \"\"\"\n",
    "    Mini-batch stochastic gradient descent.\n",
    "    lr: learning rate (step size)\n",
    "    \"\"\"\n",
    "    # Update each parameter in place using its gradient, averaged over the batch\n",
    "    for i, param in enumerate(params):\n",
    "        param.assign_sub(lr * grads[i] / batch_size)"
   ]
  },
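  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The statement that the gradient of a non-scalar loss is a sum over the examples can be checked with a small sketch. The names `scale_demo` and `xs` are hypothetical stand-ins, and a persistent tape is used only so that `gradient` can be called twice on the same recording."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "scale_demo = tf.Variable(2.0)        # a single shared parameter\n",
    "xs = tf.constant([1.0, 2.0, 3.0])    # three stand-in examples\n",
    "\n",
    "with tf.GradientTape(persistent=True) as tape:\n",
    "    per_example = scale_demo * xs          # non-scalar target, shape (3,)\n",
    "    total = tf.reduce_sum(per_example)     # explicit scalar sum\n",
    "\n",
    "# Both gradients equal sum(xs) = 6.0: a non-scalar target is summed implicitly\n",
    "print(tape.gradient(per_example, scale_demo).numpy())\n",
    "print(tape.gradient(total, scale_demo).numpy())\n",
    "del tape  # release the resources held by the persistent tape"
   ]
  },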
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
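    "## 3.2.7 Training the Model\n",
    "\n",
    "During training, we update the model parameters over many iterations. In each iteration, we take the current mini-batch (features `X` and labels `y`), record the forward computation on a `tf.GradientTape`, call `t.gradient` to compute the mini-batch stochastic gradient, and then call the optimization algorithm `sgd` to update the model parameters. Since we set the batch size `batch_size` to 10, the loss `l` of each mini-batch has shape (10, 1). Recall the section on automatic differentiation: because `l` is not a scalar, `t.gradient` returns the gradient of the sum of its elements, which is equivalent to first calling `tf.reduce_sum(l)` to obtain a scalar and then differentiating that scalar with respect to the model parameters. Because a new tape is created for every mini-batch, gradients do not accumulate across iterations, so there is no need to reset them to zero after each update.\n",
    "\n",
    "In one epoch, we iterate through `data_iter` once in full, using every example in the training set exactly once (assuming the number of examples is divisible by the batch size). The number of epochs `num_epochs` and the learning rate `lr` are both hyperparameters, set here to 3 and 0.03 respectively. In practice, most hyperparameters need to be tuned through repeated trial and error. A larger number of epochs may yield a more effective model, but training may take too long. The effect of the learning rate on the model is covered in detail in the later chapter on optimization algorithms."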
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.000051\n",
      "epoch 2, loss 0.000051\n",
      "epoch 3, loss 0.000051\n"
     ]
    }
   ],
   "source": [
    "lr = 0.03\n",
    "num_epochs = 3\n",
    "net = linreg\n",
    "loss = squared_loss\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for X, y in data_iter(batch_size, features, labels):\n",
    "        with tf.GradientTape() as t:\n",
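    "            # w and b are tf.Variable objects, so the tape watches them automatically\n",
    "            l = loss(net(X, w, b), y)\n",
    "        # Gradients of the (implicitly summed) batch loss with respect to w and b\n",
    "        grads = t.gradient(l, [w, b])\n",
    "        # Gradient descent step: update the parameters\n",
    "        sgd([w, b], lr, batch_size, grads)\n",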
    "    train_l = loss(net(features, w, b), labels)\n",
    "    print('epoch %d, loss %f' % (epoch + 1, tf.reduce_mean(train_l)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, -3.4] \n",
      " <tf.Variable 'Variable:0' shape=(2, 1) dtype=float32, numpy=\n",
      "array([[ 2.00006  ],\n",
      "       [-3.3998876]], dtype=float32)>\n",
      "4.2 \n",
      " <tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([4.200227], dtype=float32)>\n"
     ]
    }
   ],
   "source": [
    "print(true_w, '\\n', w)\n",
    "print(true_b, '\\n', b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- As we have seen, a model can be implemented quite easily using only `tf.Variable` and `tf.GradientTape`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
